import java.util.*;

class Node {
    public int val;
    public List<Node> children;

    public Node() {}

    public Node(int _val) {
        val = _val;
    }

    public Node(int _val, List<Node> _children) {
        val = _val;
        children = _children;
    }
};

/**
 * Computes the maximum depth of an N-ary tree built from {@link Node}s.
 */
public class Solution559 {

    /**
     * Returns the maximum depth of the tree rooted at {@code root}: the number of nodes on the
     * longest path from {@code root} down to a node with no children.
     *
     * <p>The tree is walked level by level with an explicit queue rather than by recursion, so a
     * very deep tree (for example a long single-child chain) cannot overflow the call stack. A
     * {@code null} children list is treated as "no children", and {@code null} entries inside a
     * children list are skipped.
     *
     * @param root the tree root; may be {@code null}
     * @return {@code 0} for a {@code null} root, otherwise the depth counting {@code root} as 1
     */
    public int maxDepth(Node root) {
        if (root == null) {
            return 0;
        }
        Deque<Node> frontier = new ArrayDeque<>();
        frontier.add(root);
        int depth = 0;
        while (!frontier.isEmpty()) {
            depth++;
            // Drain exactly the nodes of the current level; their children form the next level.
            for (int remaining = frontier.size(); remaining > 0; remaining--) {
                Node current = frontier.poll();
                if (current.children == null) {
                    continue;
                }
                for (Node child : current.children) {
                    // ArrayDeque rejects null elements, and a null child adds no depth.
                    if (child != null) {
                        frontier.add(child);
                    }
                }
            }
        }
        return depth;
    }

    /** Small demo: builds a three-level tree and prints its depth (3). */
    public static void main(String[] args) {
        Node root = new Node(1, Arrays.asList(
                new Node(2, Arrays.asList(new Node(5), new Node(6))),
                new Node(3),
                new Node(4)));
        Solution559 s = new Solution559();
        System.out.println(s.maxDepth(root));
    }
}
